import typing as tp

import arrow

from core.models._base import BaseModel, TimestampMixin
from core.services import get_db
from core.validators._base import (
    BaseValidator as CreatingValidator,
    BaseValidator as PartialUpdatingValidator,
)


class BaseRepository:
    """Placeholder base type for repositories; it defines no behavior yet."""


class CRUDRepository():
    """Generic create/read/update/delete operations for one model class.

    Listing, lookup, partial update and delete go through the model's own
    helpers (``simple_paginate``, ``find``, ``find_or_fail``, ``update``,
    ``delete``). ``create`` instead adds the instance through a session
    obtained from ``get_db()``.
    """

    # Name of the primary-key field; it is never written by partial_update.
    pk: str = 'id'
    # Declared but not assigned in this class; presumably set by subclasses
    # — TODO confirm it is used anywhere.
    pk_type: tp.Type
    # Time zone for the audit timestamps written by ``create``. Override it on
    # a subclass or instance to change it.
    timezone: str = 'Asia/Shanghai'

    def __init__(self, Model: tp.Type[BaseModel]):
        """Bind the repository to the model class ``Model``."""
        self.Model = Model

    def get_model(self):
        """Return the model class this repository operates on."""
        return self.Model

    def all(self, page_size: int = 15, page_to: int = 1):
        """Return page ``page_to`` of ``page_size`` records, serialized.

        Delegates to ``Model.simple_paginate(...).serialize()``.
        """
        return self.get_model().simple_paginate(page_size, page_to).serialize()

    def detail(self, id) -> BaseModel:
        """Return whatever ``Model.find(id)`` returns for primary key ``id``."""
        return self.Model.find(id)

    def create(self, payload: CreatingValidator):
        """Insert a new record built from ``payload`` and return it.

        If the model uses ``TimestampMixin``, the audit fields are filled in.
        ``created_at`` and ``updated_at`` get the same instant (in
        ``self.timezone``). ``created_by`` and ``updated_by`` get the
        placeholder ``0``.

        The ``get_db()`` generator is closed before returning, so its cleanup
        runs deterministically even when ``commit`` raises. The cleanup is
        presumably closing the session — confirm against ``get_db``.
        """
        client = get_db()
        db = next(client)
        try:
            model = self.Model(**payload.dict())
            if isinstance(model, TimestampMixin):
                # One timestamp, so created_at == updated_at on insert.
                now = arrow.now(self.timezone).datetime
                model.created_at = now
                model.updated_at = now
                # TODO: record the acting user instead of a placeholder id.
                model.created_by = 0
                model.updated_by = 0
            db.add(model)
            db.commit()
            db.refresh(model)
        finally:
            client.close()

        return model

    def partial_update(self, id, payload: PartialUpdatingValidator):
        """Apply the fields explicitly set in ``payload`` to record ``id``.

        Looks the record up with ``Model.find_or_fail`` and returns the
        updated instance.

        Two kinds of field are left out of the update:

        - Fields the caller did not set (``exclude_unset``), so a partial
          payload does not overwrite stored values with the validator's
          defaults.
        - The primary key ``self.pk``.

        Assumes the validator is a pydantic model, since it exposes
        ``.dict()`` — confirm.
        """
        d = self.get_model().find_or_fail(id)
        d.update(payload.dict(exclude_unset=True, exclude={self.pk}))
        return d

    def delete(self, id):
        """Delete record ``id`` and return the instance that was deleted.

        If ``Model.find`` returns None (presumably: no such record), nothing
        is deleted and None is returned. This matches ``detail``, which passes
        ``find``'s result through unchanged.
        """
        d = self.Model.find(id)
        if d is None:
            return None
        d.delete()
        return d
